# -*- coding: utf-8 -*-
"""
Created on Wed May 12 15:56:45 2021

@author: SonHyungmok
"""
import numpy as np
from scipy.optimize import curve_fit
from scipy.integrate import odeint

from scipy.optimize import fsolve

# Fractional corrections to the baseline width.  Each nominal value is
# multiplied by 0, so both corrections are currently switched off while their
# nominal sizes stay visible in the source.
sat_correct = 0 * 0.075  # 7.5 %: saturation of the in-trap imaging
magnification_correct = 0 * 0.07  # 7 %: presumably an imaging-magnification correction -- confirm
# Baseline width with both corrections folded in; evaluates to 662.0 while the
# corrections above are disabled.
width_base = 662.0 * (1 - sat_correct) * (1 + magnification_correct)

def axial_gaussian_width_base():
    """Return the module-level baseline width ``width_base``.

    This is a plain accessor: it performs no computation and simply hands
    back whatever value ``width_base`` currently holds.
    """
    return width_base

def latticeFactor(naNum):
    """Return the lattice factor for a given number ``naNum``.

    The width is taken as the constant ``width_base`` while ``naNum`` is
    below 440e3, and as the linear expression ``-2E-4 * naNum + 742.95``
    from 440e3 upwards.  The width is then scaled by ``1e-6``, multiplied
    by ``sqrt(8/pi)`` and divided by ``0.5 * 1596e-9``.

    Parameters
    ----------
    naNum : scalar
        Presumably the atom number -- confirm against callers.  The
        comparison below needs a single truth value, so a multi-element
        array is not supported.

    Returns
    -------
    numpy.float64
        The dimensionless lattice factor.
    """
    # NOTE: the 7.5 % saturation of the in-trap imaging (relevant to the
    # lattice-factor measurement) is not applied to either branch here; the
    # linear branch is used without a (1 - sat_correct) factor.
    width = width_base if naNum < 440e3 else (-2E-4 * naNum + 742.95)

    # The 1e-6 factor suggests the width is in micrometres and 1596e-9 looks
    # like a wavelength in metres -- TODO confirm units.
    return width * 1e-6 * np.sqrt(8 / np.pi) / (0.5 * 1596e-9)

# Calibration points for the width-versus-field fit.  The two arrays are
# paired element by element: data_width[i] is the width recorded at field
# data_B[i].
# A further point (B = 973.3532, width = 663.8166667) is not included in
# these arrays; it lies below B_cut, where the fit is never evaluated.
data_B = np.array([978.77912, 976.2332, 975.2732, 977.6732, 977.1932])
data_width = np.array([775.25, 718.4285714, 687.0333333, 760.5714286, 739.7])


def cubic(B, a, b, c, d):
    """Evaluate the cubic polynomial ``a*B**3 + b*B**2 + c*B + d``."""
    return a*B**3 + b*B**2 + c*B + d


def quad(B, a, b, c):
    """Evaluate the quadratic polynomial ``a*B**2 + b*B + c``."""
    return a*B**2 + b*B + c


# Least-squares fit of the quadratic model to the calibration points.  ``popt``
# holds the fitted (a, b, c) and ``pcov`` their covariance estimate.  ``cubic``
# is an alternative model that is defined above but not fitted here.
popt, pcov = curve_fit(quad, data_B, data_width)
print(popt)

# Field window over which the fitted curve is meant to be used: from the fixed
# lower cut up to the largest field present in the calibration data.
B_cut = 974.
B_peak = max(data_B)

def latticeFactor_bfield(Bfield):
    """Return the lattice factor for a given field value ``Bfield``.

    Inside the closed interval ``[B_cut, B_peak]`` the width comes from the
    fitted quadratic ``quad(Bfield, *popt)``; everywhere else the constant
    ``width_base`` is used.  The width is then scaled by ``1e-6``,
    multiplied by ``sqrt(8/pi)`` and divided by ``0.5 * 1596e-9``.

    Parameters
    ----------
    Bfield : scalar
        Field value, in the same units as the module-level ``data_B``
        calibration points.  The range test needs a single truth value, so
        a multi-element array is not supported.

    Returns
    -------
    numpy.float64
        The dimensionless lattice factor.
    """
    inside_fit_window = B_cut <= Bfield <= B_peak
    width = quad(Bfield, *popt) if inside_fit_window else width_base

    # The 1e-6 factor suggests the width is in micrometres and 1596e-9 looks
    # like a wavelength in metres -- TODO confirm units.
    return width * 1e-6 * np.sqrt(8 / np.pi) / (0.5 * 1596e-9)

# def latticeFactor_bfield(naNum, Bfield):
#     if Bfield > B_cut and Bfield < B_peak:
#         width_from_B = cubic(Bfield, *popt)
#     else:
#         width_from_B = width_base
    
#     if naNum < 440e3:
#         width_from_num = width_base * (1 - 0.075) ## 7.5% is from the saturation of the in-trap imaging for lattice factor measurement
# #         width_from_num = width_base. 
#     else:
#         width_from_num =  (-2E-4 * naNum + 742.95) * (1 - 0.075)
# #         width_from_num =  (-2E-4 * naNum + 742.95)
      
#     number_correct_ratio = width_from_num/width_base
#     width = number_correct_ratio * width_from_B

   
#     return width *1e-6 * np.sqrt(8/np.pi)/(0.5 * 1596e-9)